import language_tool_python
from tqdm import tqdm
from datasets import load_dataset
import re

# Shared LanguageTool checker configured for US English; score2 calls tool.check() on each line.
tool = language_tool_python.LanguageTool('en-US')

_ASSISTANT_MARKER = "<|assistant|>"


def _load_outputs(path, *, literal=True, pad_token="<|pad|>", end_token="<|end_of_text|>",
                  marker=None, split_raw=False):
    """Read one generation per line from *path* and strip special tokens.

    Clean-up order: remove ``pad_token``, collapse every whitespace run to a
    single space (which also trims both ends), then remove ``end_token``.

    Args:
        path: Text file with one record per line.
        literal: When true, each line is evaluated as a Python expression and
            its first element is used as the text; otherwise the raw line is
            used as-is.
        pad_token: Padding token removed before whitespace is normalised.
        end_token: End-of-text token removed after whitespace is normalised.
        marker: If given, only the text after the last occurrence of this
            marker is kept (the whole text when the marker is absent).
        split_raw: Apply ``marker`` to the text before clean-up instead of
            after it.  Needed for markers containing a newline, because
            whitespace normalisation would erase the newline first and the
            marker could then never match.

    Returns:
        List of cleaned strings, one per line of the file.
    """
    # The context manager closes the handle deterministically instead of
    # leaving a bare open(...).readlines() handle to the garbage collector.
    with open(path, "r") as handle:
        rows = handle.readlines()

    cleaned = []
    for row in rows:
        # NOTE(security): eval() executes whatever expression the line holds.
        # If every line is a plain Python literal, ast.literal_eval would be a
        # safer replacement -- confirm the file format before switching.
        text = eval(row)[0] if literal else row
        if marker is not None and split_raw:
            text = text.split(marker)[-1]
        text = ' '.join(text.replace(pad_token, "").split()).replace(end_token, "")
        if marker is not None and not split_raw:
            text = text.split(marker)[-1]
        cleaned.append(text)
    return cleaned


# Chat-style outputs: keep only what follows the last assistant marker.
target_file = _load_outputs("./target_files/normal_finetuned_outputs.txt", marker=_ASSISTANT_MARKER)
full_grad = _load_outputs("./target_files/full_grad_outputs.txt", marker=_ASSISTANT_MARKER)
german = _load_outputs("./target_files/german_outputs.txt", marker=_ASSISTANT_MARKER)
# The French and Portuguese files are consumed as raw lines, without eval.
french = _load_outputs("./target_files/french_outputs.txt", literal=False, marker=_ASSISTANT_MARKER)
portuguese = _load_outputs("./target_files/portugues_outputs.txt", literal=False, marker=_ASSISTANT_MARKER)
pretrain_data = _load_outputs("./target_files/pretrain_data_outputs.txt", marker=_ASSISTANT_MARKER)
# No marker: the full cleaned text is kept for these three.
no_chat_target = _load_outputs("./target_files/no_chat_outputs.txt")
base_target = _load_outputs("./target_files/base.txt")
tulu_base = _load_outputs("./target_files/tulu_base.txt", pad_token="<pad>")
# NOTE(review): sft.txt strips "<|endoftext|>" while every other file strips
# "<|end_of_text|>" -- confirm this spelling is intentional.
sft_target = _load_outputs("./target_files/sft.txt", end_token="<|endoftext|>", marker=_ASSISTANT_MARKER)
tulu_sft = _load_outputs("./target_files/tulu_sft.txt", pad_token="<pad>", marker=_ASSISTANT_MARKER)
dpo_target = _load_outputs("./target_files/dpo.txt", marker=_ASSISTANT_MARKER)
tulu_dpo = _load_outputs("./target_files/tulu_dpo.txt", pad_token="<pad>", marker=_ASSISTANT_MARKER)
instruct_target = _load_outputs("./target_files/instruct.txt", marker=_ASSISTANT_MARKER)
tulu_instruct = _load_outputs("./target_files/tulu_instruct.txt", pad_token="<pad>", marker=_ASSISTANT_MARKER)
# "Answer:\n" contains a newline, so it must be matched on the raw text:
# after whitespace normalisation no newline is left to match.
qa_template = _load_outputs("./target_files/qa_template.txt", marker="Answer:\n", split_raw=True)
# Instructions of the AlpacaEval "eval" split, in dataset order.
eval_set = load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval", trust_remote_code=True)["eval"]
prompts = [example["instruction"] for example in eval_set]

# For the three runs loaded without a response marker, drop as many leading
# whitespace-separated tokens as the matching prompt has, plus two more.
# NOTE(review): presumably these generations echo the prompt and the extra 2
# covers surrounding template tokens -- confirm against the generation script.
for idx in range(len(no_chat_target)):
    prefix_len = len(prompts[idx].split()) + 2
    for outputs in (no_chat_target, base_target, tulu_base):
        outputs[idx] = ' '.join(outputs[idx].split()[prefix_len:])
# Vocabulary for score1: one entry per line of english_words.txt, with
# trailing whitespace removed and upper-cased to match score1's upper-casing.
with open("english_words.txt", "r") as f:
    words = [entry.rstrip().upper() for entry in f]
valid_words = set(words)

def score1(lines, *, vocabulary=None):
    """Score how much of each line consists of recognised English words.

    Every line is upper-cased and split on whitespace.  A token counts as
    valid when, after dropping everything from its first apostrophe onward
    and removing all non-word characters, it is longer than one character
    and is a member of the vocabulary.

    Side effect: the token list of every line whose valid ratio exceeds 0.1
    is printed.

    Args:
        lines: Sequence of strings to score.
        vocabulary: Optional container of upper-case words used for the
            membership test.  Defaults to the module-level ``valid_words``.

    Returns:
        A ``(mean_ratio, mask)`` tuple.  ``mean_ratio`` is the sum of the
        per-line valid-token ratios divided by ``len(lines)``; lines without
        tokens add nothing to the sum but still count in the divisor.  It is
        ``0.0`` when ``lines`` is empty.  ``mask`` has one entry per line:
        ``1`` when at least half of the line's tokens are valid, else ``0``.
    """
    if vocabulary is None:
        vocabulary = valid_words
    non_word = re.compile(r'\W+')  # compiled once instead of once per token

    ratio_total = 0
    mask = []
    for text in lines:
        tokens = text.upper().rstrip().split()
        if not tokens:
            mask.append(0)
            continue

        valid_count = 0
        for token in tokens:
            # Keep only the part before the first apostrophe ("DON'T" -> "DON").
            stem = non_word.sub('', token.partition("'")[0])
            # Single characters are rejected even if the vocabulary lists them.
            if len(stem) > 1 and stem in vocabulary:
                valid_count += 1

        ratio = valid_count / len(tokens)
        ratio_total += ratio
        if ratio > 0.1:
            print(tokens)
        mask.append(1 if ratio >= 0.5 else 0)

    # Guard the mean: an empty input used to raise ZeroDivisionError here.
    if not mask:
        return 0.0, mask
    return ratio_total / len(mask), mask
# Word-validity score of the pretrain-data outputs; the per-line mask is reused by the score2 call below.
score1_score, mask = score1(pretrain_data)
print(score1_score)
def score2(lines, mask):
    """Average a per-line grammar score derived from LanguageTool matches.

    A line scores ``1 - matches / words``, where ``matches`` is the number of
    issues returned by ``tool.check`` and ``words`` is the number of
    whitespace-separated tokens.  The score can be negative when there are
    more matches than words.  A line scores ``0`` when its mask entry is
    ``0`` or when it contains no tokens.

    Args:
        lines: Iterable of strings to check.
        mask: Per-line 0/1 flags (as produced by ``score1``); lines flagged
            ``0`` are not sent to the checker.

    Returns:
        The mean of the per-line scores over all lines, or ``0.0`` when
        ``lines`` is empty.
    """
    per_line = []
    # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) lets the bar read
    # the length of ``lines`` and display a total.
    for index, line in enumerate(tqdm(lines)):
        tokens = line.split()
        # ``not tokens`` also covers whitespace-only lines, which previously
        # passed the empty-string check and then divided by zero.
        if mask[index] == 0 or not tokens:
            per_line.append(0)
            continue
        matches = tool.check(line)
        per_line.append(1 - len(matches) / len(tokens))

    if not per_line:
        return 0.0
    return sum(per_line) / len(per_line)
# Grammar score of the same outputs; lines masked out by score1 contribute 0.
print(score2(pretrain_data, mask))